
[JAX] Add an MoE Block (Layer) that compounds router, permutation, groupedGEMM and communication #2912

Open

tdophung wants to merge 14 commits into NVIDIA:main from tdophung:teddy/moe_block

Conversation


@tdophung tdophung commented Apr 21, 2026

Description

Most of the MoE building-block integration work has so far been deeply coupled with Maxtext development. This MoE block isolates that work from Maxtext and provides more room for experimentation. MoEBlock is a self-contained Flax-Linen module that wires together TE's fused router, pluggable token-dispatch backends (pure-JAX argsort or Triton sort_chunks_by_index), a grouped_dense-based expert FFN, and ragged-all-to-all (A2Av) expert parallelism via shard_map.

This first iteration starts with ring-of-experts EP, sharding on the batch dimension for FSDP, cuBLASLt groupedGEMM, and two permutation backends: pure JAX or Triton kernels. The block also exposes pluggable knobs for: weight layout (wi_kernel_axes / wo_kernel_axes), permutation backend, A2A vs. no-EP (single-GPU) path, data-parallelism axes for true FSDP (batch sharded across (ep, fsdp) simultaneously), top-K with optional grouped/sigmoid scoring (for the DeepSeek-V3 workload), and an optional auxiliary load-balancing loss.
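
For illustration, a minimal usage sketch of the block described above. The class and argument names are assumptions based on the knobs listed in this description (the reviewed class is private, _MoEBlock), not a confirmed signature:

import jax
import jax.numpy as jnp
from transformer_engine.jax.flax.moe import MoEBlock  # experimental module

# Hypothetical configuration; names mirror the knobs described above.
block = MoEBlock(
    num_experts=8,
    num_experts_per_tok=2,            # top-K routing
    data_parallelism_axes=("fsdp",),  # axes the batch is sharded over
    aux_loss_coeff=0.01,              # optional load-balancing loss
)
x = jnp.zeros((4, 128, 1024), jnp.bfloat16)    # [batch, seq, hidden]
params = block.init(jax.random.PRNGKey(0), x)
y, aux_loss = block.apply(params, x)           # y: [batch, seq, hidden]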

Fixes #2895

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • New transformer_engine/jax/flax/moe.py -- MoEBlock Linen module:
    gate -> fused topk -> global permute -> A2A EP shard_map (ragged_a2a fwd, local permute, 3x grouped GEMM SwiGLU FFN, local unpermute, ragged_a2a rev) -> global combine.
  • Extended transformer_engine/jax/permutation.py with A2A param helpers (compute_ragged_all_to_all_params, compute_reverse_ragged_all_to_all_params, local_permute_after_a2a, local_unpermute_before_a2a) and the pure-JAX unfused_token_dispatch / unfused_token_combine paths
    with custom VJPs.
  • tests/jax/test_moe_block.py -- single-device tests covering shape, backward, cross-backend equivalence, aux loss, group-topk, and JIT determinism.
  • tests/jax/test_distributed_moe_block.py -- EP=2 x FSDP=2 mesh test using the canonical Flax-Linen sharded-init pattern (eval_shape -> get_partition_spec -> logical_to_mesh_sharding -> jit(init, out_shardings=...), sketched just below) and data_parallelism_axes=("fsdp",) to exercise true FSDP (batch sharded across both axes).
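
A rough sketch of that sharded-init pattern, assuming `block`, a sample input `x`, a `Mesh` named `mesh`, and logical-axis `rules` are set up as in the test:

import jax
import flax.linen as nn

# Canonical Flax-Linen sharded init: infer shapes abstractly, map logical
# axes to mesh shardings, then run a jitted init constrained to them.
abstract_vars = jax.eval_shape(block.init, jax.random.PRNGKey(0), x)
logical_specs = nn.get_partition_spec(abstract_vars)
shardings = nn.logical_to_mesh_sharding(logical_specs, mesh, rules)
params = jax.jit(block.init, out_shardings=shardings)(jax.random.PRNGKey(0), x)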

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@tdophung tdophung marked this pull request as ready for review May 5, 2026 21:47

greptile-apps Bot commented May 5, 2026

Greptile Summary

This PR introduces _MoEBlock, an experimental self-contained Flax-Linen Mixture-of-Experts block for TransformerEngine JAX. It wires together the fused router, pluggable token-dispatch backends (pure-JAX argsort or Triton), cuBLASLt grouped_dense-based expert FFN, and ragged-all-to-all expert parallelism via shard_map.

  • transformer_engine/jax/flax/moe.py: New _MoEBlock module with gate, router, global permute, A2A EP path, expert FFN, and global combine stages; two top-level forward variants (no-EP and shard_map-wrapped A2A-EP).
  • transformer_engine/jax/permutation.py: Adds pure-JAX argsort-based dispatch/combine with custom VJPs, PureJaxPermState, routing_map_to_selected_experts, and ragged-all-to-all EP helpers (a sizing sketch follows this list).
  • transformer_engine/jax/sharding.py and gemm.py: Minor additions — ep_resource field on MeshResource, get_active_resource_axis helper, and removal of @cache on _should_enforce_v2_grouped_gemm to support monkeypatch.setenv in tests.
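
To make the ragged-A2A bookkeeping concrete, here is an illustrative sketch (not the compute_ragged_all_to_all_params implementation) of how per-peer send/recv row counts could be derived from group_sizes inside a shard_map over the EP axis, assuming experts are laid out contiguously across EP shards:

import jax
import jax.numpy as jnp

def a2a_send_recv_sizes(group_sizes, num_ep, axis_name="ep"):
    # Must run under shard_map/pmap with `axis_name` bound.
    # group_sizes: [num_experts] rows this shard holds per global expert.
    num_experts = group_sizes.shape[0]
    per_shard = num_experts // num_ep
    # Rows we send to peer p = rows destined for the experts p owns.
    send_sizes = group_sizes.reshape(num_ep, per_shard).sum(axis=1)
    # Gather everyone's counts so we can read how much each peer sends us.
    all_sizes = jax.lax.all_gather(group_sizes, axis_name)  # [num_ep, E]
    me = jax.lax.axis_index(axis_name)
    # Rows peer p sends to us = peer p's counts for the experts we own.
    mine = jax.lax.dynamic_slice_in_dim(
        all_sizes.reshape(num_ep, num_ep, per_shard), me, 1, axis=1
    )
    recv_sizes = mine.sum(axis=(1, 2))  # [num_ep] rows arriving per peer
    return send_sizes, recv_sizes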

Confidence Score: 4/5

Safe to merge after addressing the missing ep_axis exclusion guard in the data_parallelism_axes validation loop.

The A2A-EP forward path correctly incorporates the recv_buffer_rows alignment fix. One gap remains: if a caller passes the EP axis name in data_parallelism_axes, the batch PartitionSpec gets a duplicate axis and dp_size is double-counted, producing an undersized ragged_all_to_all receive buffer with no useful error message.

transformer_engine/jax/flax/moe.py — specifically the data_parallelism_axes validation block in _forward_a2a_ep.

Important Files Changed

  • transformer_engine/jax/flax/moe.py -- New 1174-line _MoEBlock Linen module; missing a validation, so ep_axis may appear in data_parallelism_axes, which produces a duplicate-axis PartitionSpec and an undersized recv buffer.
  • transformer_engine/jax/permutation.py -- Adds pure-JAX dispatch/combine with custom VJPs and ragged-A2A helpers; logic is correct, and the recv_buffer_rows alignment fix is in place.
  • transformer_engine/jax/sharding.py -- Adds ep_resource to MeshResource and the get_active_resource_axis helper; consistent with existing axis-resolution patterns.
  • transformer_engine/jax/cpp_extensions/gemm.py -- Removes @cache from _should_enforce_v2_grouped_gemm so monkeypatch.setenv works in tests (illustrated after this list); negligible performance impact.
  • tests/jax/test_moe_block.py -- Comprehensive single-device tests covering shape, backward, cross-backend equivalence, aux loss, group-topk, align_size, and JIT determinism.
  • tests/jax/test_distributed_moe_block.py -- EP=2 x FSDP=2 distributed test using the canonical Flax-Linen sharded-init pattern; validates output, loss, aux_loss, and per-parameter gradients.
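
The gemm.py change is easy to illustrate: functools.cache freezes the first result of an environment-dependent helper, so a later monkeypatch.setenv has no effect. A standalone sketch (the flag name and helper are made up for the example):

import os
from functools import cache

@cache
def should_enforce_v2():  # stand-in for the previously cached helper
    return os.environ.get("SOME_TE_FLAG", "0") == "1"

should_enforce_v2()                   # first call caches the result (False)
os.environ["SOME_TE_FLAG"] = "1"      # what monkeypatch.setenv does in a test
assert should_enforce_v2() is False   # stale: the cached value wins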

Sequence Diagram

sequenceDiagram
    participant Caller
    participant _MoEBlock
    participant Router
    participant GlobalPermute
    participant A2A as ragged_all_to_all (EP)
    participant LocalPerm as local_permute_after_a2a
    participant ExpertFFN as _expert_ffn (grouped_dense x3)
    participant GlobalCombine

    Caller->>_MoEBlock: inputs [B, S, H]
    _MoEBlock->>Router: "gate_logits -> fused_topk_with_score_function"
    Router-->>_MoEBlock: sparse_probs, routing_map
    _MoEBlock->>GlobalPermute: _global_permute (pure_jax or triton)
    GlobalPermute-->>_MoEBlock: sorted_inputs, group_sizes [E]

    alt No-EP path
        _MoEBlock->>ExpertFFN: "sorted_inputs, group_sizes, n_groups=E"
        ExpertFFN-->>_MoEBlock: expert_outputs
    else A2A-EP path via shard_map
        _MoEBlock->>A2A: all_gather(group_sizes)
        A2A->>A2A: forward ragged_all_to_all over ep axis
        A2A->>LocalPerm: reorder recv buffer
        LocalPerm-->>A2A: sorted_x, local_group_sizes
        A2A->>ExpertFFN: sorted_x, local_group_sizes
        ExpertFFN-->>A2A: expert_outputs
        A2A->>LocalPerm: local_unpermute_before_a2a
        A2A->>A2A: reverse ragged_all_to_all
        A2A-->>_MoEBlock: y_back
    end

    _MoEBlock->>GlobalCombine: _global_combine
    GlobalCombine-->>_MoEBlock: output [B, S, H]
    _MoEBlock-->>Caller: output [B, S, H], aux_loss

Reviews (6): Last reviewed commit: "change naming and add message for experi..."

tdophung added 6 commits May 5, 2026 16:35
Signed-off-by: tdophung <tdophung@nvidia.com>
…ody single GPU vs. multi GPU

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
…e and single device initial params in the MoEBlock. Tests should pass now

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung force-pushed the teddy/moe_block branch from 8a838f3 to 6aeb491 Compare May 5, 2026 23:44
pre-commit-ci Bot and others added 2 commits May 5, 2026 23:45
Signed-off-by: tdophung <tdophung@nvidia.com>
Comment on lines +427 to +457
def _compute_aux_loss(
    self,
    logits_2d: jnp.ndarray,
) -> Optional[jnp.ndarray]:
    """Compute the MoE auxiliary load-balancing loss.

    The score-for-aux kernel has no data dependency on the main
    routing kernel, so XLA can overlap them on the GPU.

    ``logits_2d`` should be the *full* logits tensor over the global
    token batch -- under EP the caller is responsible for
    :func:`jax.lax.all_gather` ing the logits before calling this so
    the aux_loss formula
    ``loss = (E * coeff / (k * T^2)) * sum_i(sum_t(probs[t,i]) * tokens[i])``
    sees the global ``T`` and the global ``tokens_per_expert``.
    """
    if self.aux_loss_coeff <= 0.0:
        return None
    aux_scores, aux_routing_map = fused_topk_with_score_function(
        logits_2d.astype(jnp.float32),
        topk=self.num_experts_per_tok,
        score_function=self.score_function,
        compute_aux_scores=True,
    )
    aux_tokens_per_expert = jnp.sum(aux_routing_map.astype(jnp.int32), axis=0)
    return fused_moe_aux_loss(
        aux_scores.astype(jnp.float32),
        aux_tokens_per_expert,
        topk=self.num_experts_per_tok,
        coeff=self.aux_loss_coeff,
    )
Contributor

P1 Aux loss tokens_per_expert is inconsistent with actual grouped-topk routing

When num_groups > 0 and group_topk > 0 (DeepSeek-style routing), fused_topk_with_score_function(..., compute_aux_scores=True) intentionally ignores those parameters and runs a clean standard top-k. The returned aux_routing_map therefore reflects different expert selections than the actual routing_map produced by _route_topk, causing aux_tokens_per_expert = sum(aux_routing_map, axis=0) to count a different token–expert distribution. Any user who combines num_groups > 0 + group_topk > 0 + aux_loss_coeff > 0 silently trains with a wrong auxiliary objective. The existing test_group_topk_deepseek test does not catch this because it leaves aux_loss_coeff at its default of 0.0.
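
A regression sketch for this gap (hypothetical test code, reusing the illustrative MoEBlock construction from the description; the essential ingredient is aux_loss_coeff > 0 combined with grouped top-k):

import jax
import jax.numpy as jnp

def test_group_topk_with_aux_loss():
    # Unlike test_group_topk_deepseek, set a non-zero aux_loss_coeff so
    # the aux scoring path runs alongside grouped top-k routing.
    block = MoEBlock(
        num_experts=8,
        num_experts_per_tok=2,
        num_groups=4,
        group_topk=2,
        aux_loss_coeff=0.01,
    )
    x = jax.random.normal(jax.random.PRNGKey(0), (2, 16, 64), jnp.float32)
    params = block.init(jax.random.PRNGKey(1), x)
    _, aux_loss = block.apply(params, x)
    # A real assertion would compare tokens-per-expert derived from the
    # actual grouped routing_map against the counts the aux path uses;
    # with the kernel behaviour described above, the two would disagree.
    assert aux_loss is not None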

nvjax and others added 2 commits May 7, 2026 15:18
…int in C++ files, make FP8 works. Tested with current scaling

Signed-off-by: JAX Toolbox <jax@nvidia.com>
… grad tol to 5e-2, move arch/align_size docs into MoEBlock class

Signed-off-by: tdophung <tdophung@nvidia.com>
Comment on lines +909 to +914
batch_divisor = num_ep * dp_size
if global_batch_size % batch_divisor != 0:
    raise ValueError(
        f"batch={global_batch_size} not divisible by prod(data_parallelism_axes)={dp_size}"
    )
recv_buffer_rows = (global_batch_size // dp_size) * sequence_length * topk
Contributor

P1 Receive buffer undersized when align_size > 0 + EP are combined

recv_buffer_rows is computed assuming unpadded token counts, but when align_size > 0 the per-expert group_sizes are the aligned counts, so the send_sizes in compute_ragged_all_to_all_params include padding tokens. The worst-case receive per shard is num_ep * ((B/(num_ep*dp_size))*S*K + num_experts_per_shard*(align_size-1)), which exceeds the current recv_buffer_rows = (B/dp_size)*S*K by up to num_experts*(align_size-1) rows. ragged_all_to_all writing beyond the buffer produces incorrect results or a crash. The correct worst-case size is:

recv_buffer_rows = (global_batch_size // dp_size) * sequence_length * topk + num_experts * (self.align_size - 1 if self.align_size > 0 else 0)

This combination (EP + align_size > 0) is not exercised by the current distributed test (which defaults to align_size=0), so the bug is latent.
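
A quick numeric check of that bound, with illustrative sizes (not taken from the PR):

# Illustrative numbers only.
B, S, K = 8, 128, 2        # global batch, sequence length, topk
num_ep, dp_size = 2, 2     # EP shards, prod(data_parallelism_axes)
num_experts, align_size = 8, 16

current = (B // dp_size) * S * K                  # 1024 rows
worst = current + num_experts * (align_size - 1)  # 1024 + 120 = 1144 rows
# Equivalently, per the comment above:
# num_ep * ((B // (num_ep * dp_size)) * S * K
#           + (num_experts // num_ep) * (align_size - 1)) = 1144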

Collaborator

@phu0ngng phu0ngng left a comment

I think we should go with exposing the GroupMLP VJP first, before the MoE module, to enable possible future fusions.
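
For context, the general shape of such an exposed VJP; a sketch only, with a plain dot product standing in for the grouped GEMM and with the grouped dimension elided (grouped_mlp and its signature are hypothetical):

import jax
import jax.numpy as jnp

@jax.custom_vjp
def grouped_mlp(x, w):
    return jnp.dot(x, w)  # placeholder for the cuBLASLt grouped GEMM

def _grouped_mlp_fwd(x, w):
    # Save residuals needed by the backward pass.
    return grouped_mlp(x, w), (x, w)

def _grouped_mlp_bwd(residuals, g):
    x, w = residuals
    # Hand-written gradients give later fusion passes a stable boundary.
    return jnp.dot(g, w.T), jnp.dot(x.T, g)

grouped_mlp.defvjp(_grouped_mlp_fwd, _grouped_mlp_bwd)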

tdophung added 3 commits May 12, 2026 15:53
…ing None as group_topk, align_size rename,

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Comment on lines +923 to +928
for ax in self.data_parallelism_axes:
    if ax not in mesh.shape:
        raise ValueError(
            f"data_parallelism_axes contains {ax!r} but mesh has"
            f" axes {tuple(mesh.shape.keys())}"
        )
Contributor

P1 The validation loop checks that every axis in data_parallelism_axes exists in the mesh but does not check that the axis differs from ep_axis. If a caller passes data_parallelism_axes=("ep",) when ep_axis="ep", batch_pspec_axis becomes ("ep", "ep") — a duplicate-axis PartitionSpec that JAX rejects with a cryptic error. Independently, dp_size accumulates mesh.shape["ep"] a second time, so recv_buffer_rows is undersized by a factor of num_ep and batch_divisor becomes num_ep², both causing wrong runtime behaviour before JAX ever sees the bad spec.

Suggested change -- before:

for ax in self.data_parallelism_axes:
    if ax not in mesh.shape:
        raise ValueError(
            f"data_parallelism_axes contains {ax!r} but mesh has"
            f" axes {tuple(mesh.shape.keys())}"
        )

Suggested change -- after:

for ax in self.data_parallelism_axes:
    if ax not in mesh.shape:
        raise ValueError(
            f"data_parallelism_axes contains {ax!r} but mesh has"
            f" axes {tuple(mesh.shape.keys())}"
        )
    if ax == ep_axis:
        raise ValueError(
            f"data_parallelism_axes contains {ax!r}, which is the same as the"
            f" EP axis {ep_axis!r}. The EP axis is already included in the batch"
            " sharding spec; listing it again produces a duplicate-axis"
            " PartitionSpec and an undersized ragged_all_to_all receive buffer."
        )



Development

Successfully merging this pull request may close these issues.

[JAX] Create initial MoE Block

4 participants